#utils for contribution-vote-prediction analysis
from zipfile import ZipFile
import pandas as pd
import functools
import numpy as np
import joblib
import copy
from sklearn.base import BaseEstimator
from collections import Counter, defaultdict
import json
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import accuracy_score
from scipy.stats import kendalltau
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

# Key needed to query the ProPublica congressional records dataset.
# Defaults to None and nothing in the visible part of this file reads it;
# presumably callers assign it before querying -- TODO confirm.
propublica_api_key = None

def arr2str(array):
  """Serialize a numeric sequence as the source text of an ``np.array`` call.

  Every element is written with 15 decimal places, e.g.
  ``'np.array([1.000000000000000,2.500000000000000])'``, so the value can be
  stored in a CSV cell and rebuilt later by evaluating the text.

  Parameters
  ----------
  array : iterable of numbers
    Values to serialize.

  Returns
  -------
  str
    The ``np.array([...])`` text, or ``'np.nan'`` when ``array`` cannot be
    iterated or contains entries that cannot be formatted as floats
    (e.g. a missing cell holding NaN or None).
  """
  try:
    return 'np.array([' + ','.join('%.15f' % value for value in array) + '])'
  except (TypeError, ValueError, OverflowError):
    # Best-effort fallback for missing / non-numeric cells. Spelled 'np.nan'
    # because the 'np.NaN' alias was removed in NumPy 2.0, which would make
    # the text impossible to evaluate back into a value.
    return 'np.nan'

def load_csv(fname, **kwargs):
  """Read a CSV file and turn Python-expression text columns back into objects.

  After ``pd.read_csv`` every column is tentatively evaluated cell by cell.
  A column is replaced only when *all* of its cells evaluate without error
  (e.g. cells holding ``'np.array([...])'`` text); any other column is left
  exactly as pandas parsed it.

  Parameters
  ----------
  fname : path or file-like
    Passed straight to ``pd.read_csv``.
  **kwargs
    Extra keyword arguments forwarded to ``pd.read_csv``.

  Returns
  -------
  pandas.DataFrame
  """
  df = pd.read_csv(fname, **kwargs)

  def _parse_cell(cell):
    # SECURITY: eval executes arbitrary code found in the file -- only use
    # this loader on CSVs you produced yourself.
    # Evaluating against this module's globals makes names such as `np`
    # resolve here instead of in whichever pandas frame invokes the callback.
    return eval(cell, globals())

  for col in df.columns:
    try:
      parsed = df[col].apply(_parse_cell)
    except Exception:
      # Not a column of Python expressions (plain text, numbers, ...):
      # keep it untouched. KeyboardInterrupt/SystemExit still propagate.
      continue
    df[col] = parsed
  return df

def to_csv(df, fname, arr2str_cols=lambda c: c.startswith('embedding_')):
  """Write ``df`` to ``fname`` as CSV, serializing array-valued columns as text.

  Columns whose name satisfies ``arr2str_cols`` are converted cell by cell
  with ``arr2str`` before writing. The file is written without the index.

  Parameters
  ----------
  df : pandas.DataFrame
    Frame to save. It is not modified.
  fname : path or file-like
    Destination passed to ``DataFrame.to_csv``.
  arr2str_cols : callable
    Predicate on the column name selecting the columns to serialize;
    by default the columns whose name starts with ``'embedding_'``.
  """
  # Shallow copy: the untouched column data is shared with `df` (no extra
  # memory), while the columns replaced below are rebound only in the copy,
  # so the caller's arrays are no longer overwritten with strings.
  out = df.copy(deep=False)
  print('processing columns')
  for col in out.columns:
    if arr2str_cols(col):
      out[col] = out[col].apply(arr2str)
  out.to_csv(fname, index=False)

def read_zipped_dfs(fname, low_memory=False, _filter=lambda x: x):
  """Load every ``.csv`` member of a zip archive into one DataFrame.

  Members are read in archive order, each one is passed through ``_filter``
  and the results are concatenated row-wise (original row indices are kept).

  Parameters
  ----------
  fname : path or file-like
    Zip archive to read.
  low_memory : bool
    Forwarded to ``pd.read_csv``.
  _filter : callable
    Applied to each per-file DataFrame before concatenation; the default
    keeps the frame unchanged.

  Returns
  -------
  pandas.DataFrame

  Raises
  ------
  ValueError
    If the archive contains no member whose name ends with ``.csv``.
  """
  frames = []
  # Context managers make sure the archive and every member handle are closed.
  with ZipFile(fname) as archive:
    for member in archive.infolist():
      if not member.filename.endswith('.csv'):
        continue
      with archive.open(member) as handle:
        frames.append(_filter(pd.read_csv(handle, low_memory=low_memory)))
  if not frames:
    raise ValueError('no .csv members found in %s' % (fname,))
  # A single concat avoids re-copying the accumulated rows for every file.
  return pd.concat(frames)

def load(fname):
  """Return the object stored in ``fname``, deserialized with ``joblib.load``."""
  # NOTE(review): joblib files are pickle-based, so only open trusted files.
  restored = joblib.load(fname)
  return restored

class Legislator(object):
  #generic class to record Legislator attributes
  #can be extended with arbitrary attributes either individually 
  #or by wrapper class LegislatorsDataset
  def __init__(self, **kwargs):
    self.all_attrs = kwargs
    for kw, arg in kwargs.items():
      if kw != 'Text' and type(arg) == str:
        try: arg = eval(arg)
        except: pass
      
      if kw == 'Datetime':	#add a Date field
        setattr(self, 'Date', arg.split()[0])
      setattr(self, kw, arg)

class LegislatorsDataset(object):
  """Column-style view over a list of ``Legislator``-like records.

  * ``ds.attr`` collects ``attr`` from every record (as an array if possible).
  * ``ds.attr = values`` assigns one value per record.
  * ``ds[i]`` returns a record for an integer or a name (looked up through
    the ``namefield`` attribute of the records), and a new dataset of
    shallow-copied records for a slice or an iterable of indices.

  Parameters
  ----------
  _legislators : list, optional
    Ready-made records. When given, the remaining keyword arguments are
    ignored.
  namefield : str
    Name of the record attribute used for string indexing.
  **kwargs
    Equal-length sequences, one per attribute; record ``k`` is built from
    the ``k``-th element of every sequence.
  """
  # Attributes stored on the dataset itself instead of on the records.
  _DATASET_ATTRS = ('_legislators', 'namefield')

  def __init__(self, _legislators=None, namefield='Name', **kwargs):
    super(LegislatorsDataset, self).__init__()
    self.namefield = namefield
    if _legislators is None:
      keys = list(kwargs.keys())
      _legislators = [ Legislator(**dict(zip(keys, vals)))
                       for vals in zip(*kwargs.values()) ]
    # Must be assigned last: its presence marks the dataset as initialized.
    self._legislators = _legislators

  def __getattr__(self, attr_name):
    # Only reached when regular attribute lookup has failed.
    try:
      return vars(self)[attr_name]
    except KeyError:
      pass
    if attr_name in self._DATASET_ATTRS:
      # Not set yet (e.g. instance created without __init__). Falling through
      # to the per-record lookup would read self._legislators again and
      # recurse without bound.
      raise AttributeError(attr_name)
    return self.__collectattrs__(attr_name)

  def __setattr__(self, attr_name, val):
    initialized = '_legislators' in vars(self)
    if not initialized or attr_name in self._DATASET_ATTRS:
      # Before initialization, and for dataset-level attributes, use the
      # default attribute-setting behavior.
      super(LegislatorsDataset, self).__setattr__(attr_name, val)
      return
    assert len(val) == len(self._legislators), \
           "length of any new dataset attribute must match number of legislators"
    for v, legislator in zip(val, self._legislators):
      setattr(legislator, attr_name, v)

  def __iter__(self):
    return iter(self._legislators)

  def __getitem__(self, i):
    # numpy integers are accepted as well: they are what iterating an index
    # array (as done by `sample`) yields.
    if isinstance(i, (int, np.integer)):
      return self._legislators[i]
    if isinstance(i, slice):
      sliced_legislators = [ copy.copy(legislator)
                             for legislator in self._legislators[i] ]
      return LegislatorsDataset(_legislators=sliced_legislators,
                                namefield=self.namefield)
    if isinstance(i, str):
      return self._legislators_dict[i]
    try:
      indices = list(i)
    except TypeError:
      print('cannot index with %s' % str(type(i)))
      raise TypeError('cannot index with %s' % str(type(i)))
    enum_legislators = [ copy.copy(self[j]) for j in indices ]
    return LegislatorsDataset(_legislators=enum_legislators,
                              namefield=self.namefield)

  @property
  def _legislators_dict(self, namefield=None):
    # Maps the value of each record's name attribute to the record itself.
    if namefield is None: namefield = self.namefield
    return { getattr(legislator, namefield): legislator
             for legislator in self._legislators }

  def __len__(self):
    return len(self._legislators)

  def __collectattrs__(self, attr_name):
    """Gather ``attr_name`` from every record (None where it is missing).

    Raises AttributeError when no record has the attribute, which includes
    the case of an empty dataset.
    """
    missing = object()
    found = [ getattr(legislator, attr_name, missing)
              for legislator in self._legislators ]
    if all(value is missing for value in found):
      if attr_name != '__getstate__':
        print('no legislator has this attribute "%s"' % attr_name)
      raise AttributeError(attr_name)
    all_of_attr = [ None if value is missing else value for value in found ]
    # If the attribute is array-able, return it as an array.
    try:
      return np.array(all_of_attr)
    except Exception:
      return all_of_attr

  def __getstate__(self):
    return vars(self)

  def __setstate__(self, state):
    # Writes straight into the instance dict, bypassing __setattr__.
    vars(self).update(state)

  def __iadd__(self, b):
    assert isinstance(b, LegislatorsDataset)
    self._legislators += b._legislators
    return self

  def __add__(self, b):
    assert isinstance(b, LegislatorsDataset)
    _cat_legislators = self._legislators + b._legislators
    # Keep this dataset's name attribute instead of resetting it to 'Name'.
    return LegislatorsDataset(_legislators=_cat_legislators,
                              namefield=self.namefield)

  def save(self, fname):
    if not fname.endswith('.lds'):
      fname += '.lds' #.lds as canonical "LegislatorsDataset" file extension
    joblib.dump(self, fname)

  def apply_filter(self, boolean_filter):
    """Return a new dataset with copies of the selected records.

    ``boolean_filter`` is either a callable taking a record and returning a
    truth value, or a sequence of truth values aligned with the records.
    E.g. ``lambda legislator: True`` selects every record.
    """
    if callable(boolean_filter):
      out_legislators = [ copy.copy(legislator) for legislator
                          in self._legislators if boolean_filter(legislator) ]
    else:
      out_legislators = [ copy.copy(legislator) for (keep, legislator) in
                          zip(boolean_filter, self._legislators) if keep ]
    return LegislatorsDataset(_legislators=out_legislators,
                              namefield=self.namefield)

  def sample(self, sample_size):
    """Return a dataset of ``sample_size`` records drawn without replacement."""
    perm = np.random.permutation(len(self))
    return self[perm[:sample_size]]


class ConditionalMajorityVote(BaseEstimator):
  """Baseline classifier predicting the most frequent label of each category.

  Every sample is a single hashable category value. ``fit`` counts, for each
  category, how often each label occurs; ``predict`` returns the most
  frequent label of the sample's category, or a label drawn at random from
  the labels seen in training when the category is unknown.

  Attributes set by ``fit``
  -------------------------
  categories : sorted list of the category values seen in training
  labels : sorted list of the labels seen in training
  majorityvote : dict mapping each category to its most frequent label
  """
  def __init__(self):
    self.params = {}

  def get_params(self, deep=True):
    # The estimator has no hyper-parameters.
    return self.params

  def fit(self, X_, y):
    """Learn the majority label of every category and return ``self``."""
    categories_and_votes = Counter(zip(X_, y))
    votes_by_cat = defaultdict(dict)
    for (x, y_), count in categories_and_votes.items():
      votes_by_cat[x][y_] = count
    self.categories = sorted(votes_by_cat)
    self.labels = sorted({ y_ for _, y_ in categories_and_votes })
    # Ties go to the label that was encountered first for that category.
    self.majorityvote = { x: max(votes, key=votes.get)
                          for x, votes in votes_by_cat.items() }
    # Returning self follows the scikit-learn estimator contract and allows
    # chaining such as `ConditionalMajorityVote().fit(X, y).predict(X)`.
    return self

  def _predict_label(self, x):
    try:
      return self.majorityvote[x]
    except KeyError:
      # Category never seen in training: guess uniformly among known labels.
      return np.random.choice(self.labels)

  def predict(self, X):
    """Return an array with one predicted label per sample of ``X``."""
    return np.array([ self._predict_label(x) for x in X ])

  def score(self, X, y):
    """Return the fraction of samples whose prediction equals ``y``."""
    y_ = np.array(y)
    prediction = self.predict(X)
    return (prediction == y_).sum()/y_.shape[0]

def permute_data(*args):
  """Reorder every argument with one shared random permutation.

  The permutation length is the first dimension of the first argument.
  Each argument is converted to a numpy array and indexed with it. A single
  argument yields the shuffled array itself, several arguments a tuple of
  shuffled arrays.
  """
  n_rows = np.array(args[0]).shape[0]
  order = np.random.permutation(n_rows)
  shuffled = tuple([ np.array(sequence)[order] for sequence in args ])
  return shuffled[0] if len(shuffled) == 1 else shuffled

def party_indicator(parties):
  """One-hot encode party names into an ``(n, 3)`` float array.

  Column 0 marks 'Democrat', column 1 'Republican' and column 2 both
  'Independent' and 'Libertarian'. ``parties`` must expose ``.shape``;
  a party name outside these four raises KeyError.
  """
  column_of = {'Democrat': 0, 'Republican': 1, 'Independent': 2, 'Libertarian': 2}
  encoded = np.zeros((parties.shape[0], 3))
  columns = np.array([ column_of[party] for party in parties ], dtype=int)
  # Set one cell per row in a single fancy-indexing assignment.
  encoded[np.arange(len(columns)), columns] = 1.
  return encoded

def json2df(json_fname):
  """Load a JSON file holding a list of records into a DataFrame.

  The columns are the union of the fields of all records, in order of first
  appearance. Fields a record lacks are filled with NaN, and list values are
  converted to tuples.

  Parameters
  ----------
  json_fname : str
    Path of the JSON file.

  Returns
  -------
  pandas.DataFrame
    One row per record.
  """
  with open(json_fname, 'r') as f:
    data = json.load(f)
  # A dict serves as an ordered set: first-seen order, O(1) membership.
  fields = {}
  for item in data:
    for field in item:
      fields.setdefault(field, None)
  columns = { field: [] for field in fields }
  for item in data:
    for field, column in columns.items():
      try:
        entry = item[field]
      except KeyError:
        # `np.nan` rather than `np.NaN`: the latter alias was removed in
        # NumPy 2.0 and raises AttributeError there.
        column.append(np.nan)
        continue
      if isinstance(entry, list):
        entry = tuple(entry)
      column.append(entry)
  return pd.DataFrame(columns)


def make_indicator(groups):
  """One-hot encode group labels.

  The columns follow the order of ``np.unique(groups)``. Row ``k`` holds a
  single 1. in the column of ``groups[k]``. The rows are stacked with
  ``np.array``.
  """
  distinct = np.unique(groups)
  column_of = dict(zip(distinct, range(len(distinct))))
  width = len(column_of)

  def one_hot(position):
    # Build the indicator row for one sample.
    row = np.zeros(width)
    row[position] = 1.
    return row

  return np.array([ one_hot(column_of[group]) for group in groups ])

def predict_and_score(*args, score_metric=accuracy_score, **kwargs):
  """Run ``cross_val_predict`` and score its predictions against the labels.

  All positional and keyword arguments are forwarded unchanged to
  ``cross_val_predict`` (estimator, X, y, ...). The true labels are taken
  from the third positional argument or, failing that, from the ``y``
  keyword.

  Parameters
  ----------
  score_metric : callable
    Called as ``score_metric(y, predictions)``; defaults to accuracy.

  Returns
  -------
  tuple
    ``(predictions, score)``.

  Raises
  ------
  TypeError
    If the labels are passed neither positionally nor as ``y=``.
  """
  # Locate the labels before the (potentially slow) cross-validation so a
  # bad call fails immediately. Only y is needed here; X is never read.
  if len(args) > 2:
    y = args[2]
  elif 'y' in kwargs:
    y = kwargs['y']
  else:
    raise TypeError('predict_and_score needs the true labels, either as the '
                    'third positional argument or as y=...')
  predictions = cross_val_predict(*args, **kwargs)
  return predictions, score_metric(y, predictions)

def tau(x,y):
  """Return Kendall's tau of x and y (first entry of scipy's ``kendalltau`` result)."""
  result = kendalltau(x, y)
  return result[0]

def tau_partial(x,y,z):
  """Partial Kendall tau of x and y, controlling for the covariate z.

  Computed as
  ``(tau_xy - tau_xz*tau_yz) / (sqrt(1 - tau_xz**2) * sqrt(1 - tau_yz**2))``.
  The denominator is zero when x or y is perfectly (anti-)correlated with z.
  """
  # Each pairwise tau is evaluated once. Kendall's tau is symmetric in its
  # arguments, so the y-z value serves both numerator and denominator.
  tau_xy = tau(x, y)
  tau_xz = tau(x, z)
  tau_yz = tau(y, z)
  numerator = tau_xy - tau_xz*tau_yz
  denominator = np.sqrt(1 - tau_xz**2)*np.sqrt(1 - tau_yz**2)
  return numerator/denominator

def get_triu(arr, offset=1):
  """Return the upper-triangle entries of ``arr`` as a flat array.

  ``offset`` is the diagonal to start from: the default 1 skips the main
  diagonal, 0 includes it. The triangle size is ``arr.shape[0]``.
  """
  size = arr.shape[0]
  row_idx, col_idx = np.triu_indices(size, offset)
  return arr[row_idx, col_idx]


class SeabornFig2Grid():
    #from stackoverflow respondent
    #https://stackoverflow.com/questions/35042255/how-to-plot-multiple-seaborn-jointplot-in-subplot
    #
    #Re-parents the axes of a seaborn grid object (FacetGrid, PairGrid or
    #JointGrid) into a slot of another matplotlib figure, so that several
    #seaborn grids can be laid out together in one figure.
    def __init__(self, seaborngrid, fig,  subplot_spec):
        """Move the axes of `seaborngrid` into `fig` at `subplot_spec`.

        seaborngrid  -- the seaborn grid object whose axes are moved
        fig          -- the target figure
        subplot_spec -- slot of the target figure's layout to fill; it is
                        passed as `subplot_spec=` to GridSpecFromSubplotSpec
        """
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        # Dispatch on the grid type. Any other type moves no axes, but
        # _finalize() below still runs for it.
        if isinstance(self.sg, sns.axisgrid.FacetGrid) or \
            isinstance(self.sg, sns.axisgrid.PairGrid):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()
    
    def _movegrid(self):
        """Move a PairGrid or FacetGrid: one sub-grid cell per seaborn axis."""
        self._resize()
        # The sub-grid mirrors the shape of the seaborn axes array.
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i,j], self.subgrid[i,j])
    
    def _movejointgrid(self):
        """Move a JointGrid: joint axis plus its two marginal axes."""
        # Height ratio of the joint axis to the top marginal axis, read
        # before _resize() changes the size of the seaborn figure.
        h= self.sg.ax_joint.get_position().height
        h2= self.sg.ax_marg_x.get_position().height
        r = int(np.round(h/h2))
        self._resize()
        # (r+1) x (r+1) cells: the first row and the last column are used
        # for the marginals, the remaining r x r block for the joint axis.
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r+1,r+1, subplot_spec=self.subplot)
    
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])
    
    def _moveaxes(self, ax, gs):
        """Detach `ax` from its figure and attach it to self.fig at `gs`."""
        #https://stackoverflow.com/a/46906599/4124317
        ax.remove()
        ax.figure=self.fig
        # NOTE(review): if Figure.axes returns a fresh list on each access,
        # this append has no effect and add_axes below does the real
        # registration -- confirm against the matplotlib version in use.
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        # NOTE(review): writes a private attribute and then calls the public
        # setter with the same value -- presumably kept for compatibility
        # across matplotlib versions; confirm.
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)
    
    def _finalize(self):
        """Close the seaborn figure, hook up resizing and redraw the target."""
        plt.close(self.sg.fig)
        # Re-run _resize on every resize event of the target figure.
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()
    
    def _resize(self, evt=None):
        """Give the seaborn figure the size of the target figure.

        `evt` is accepted so the method can serve as an event callback;
        it is not used.
        """
        self.sg.fig.set_size_inches(self.fig.get_size_inches())

def make_rank_series(x):
  """Return the 0-based rank of every element of a 1-D sequence.

  Without ties the ranks are the integers ``0 .. n-1``. Tied values all get
  the average of the positions they occupy in sorted order (midrank), on the
  same 0-based scale, so both cases agree: ``[1, 2, 2]`` gives
  ``[0.0, 1.5, 1.5]``.

  Parameters
  ----------
  x : 1-D array-like
    Values to rank.

  Returns
  -------
  numpy.ndarray
    Integer ranks when all values are distinct, float midranks otherwise.
  """
  x = np.array(x)
  x_vals, inverse, x_counts = np.unique(x, return_inverse=True,
                                        return_counts=True)
  if x_vals.shape[0] == x.shape[0]:
    # No ties: a double argsort yields the rank of every element.
    return x.argsort().argsort()
  # A value with `count` copies whose first copy sits at sorted position
  # `first` occupies positions first .. first+count-1; their mean is
  # first + (count-1)/2. (Averaging `first` and `first+count` instead would
  # shift every rank by +0.5 relative to the no-tie case.)
  first = x_counts.cumsum() - x_counts
  midranks = first + (x_counts - 1)/2
  # `inverse` maps every element to its unique value; indexing with it also
  # works for NaN, where a dict lookup by value would fail.
  return midranks[inverse.reshape(x.shape)]





